# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hysop.backend.device.opencl.opencl_types import (
basetype as cl_basetype,
components as cl_components,
vtype as cl_vtype,
)
from hysop.backend.device.codegen.base.variables import ctype_to_dtype
import sympy as sm
from hysop.symbolic import Symbol, Expr
from hysop.symbolic.array import OpenClSymbolicBuffer, OpenClSymbolicNdBuffer
from hysop.tools.htypes import check_instance, first_not_None, to_tuple, to_list
from hysop.tools.numerics import is_fp, is_signed, is_unsigned, is_integer, is_complex
from packaging import version
# Location of the C99 printer depends on the installed sympy version:
# releases newer than 1.7 are imported from ``sympy.printing.c``, older ones
# from ``sympy.printing.ccode``.
if not version.parse(sm.__version__) > version.parse("1.7"):
    from sympy.printing.ccode import C99CodePrinter
else:
    from sympy.printing.c import C99CodePrinter

# Sentinel returned by ``_ccode`` hooks that emit their code directly through
# ``printer.codegen`` instead of returning an expression string;
# ``OpenClPrinter.doprint`` compares the printed result against it.
InstructionTermination: str = ""
class TypedI:
    """
    Mixin attaching OpenCL C type information to symbolic objects.

    Subclasses set a ``ctype`` attribute (an OpenCL type name string); the
    properties below derive the base type, component count, dtype (via
    ``ctype_to_dtype``) and signedness information from it.

    The optional ``positive`` keyword is consumed here (it is not forwarded to
    the next ``__new__`` in the MRO) and stored as the ``positive`` attribute.
    """

    def __new__(cls, *args, **kwds):
        positive = kwds.pop("positive", None)
        obj = super().__new__(cls, *args, **kwds)
        obj.positive = positive
        return obj

    @classmethod
    def vtype(cls, btype, n):
        """Build the vector type for base type ``btype`` with ``n`` components."""
        return cl_vtype(btype, n)

    @property
    def btype(self):
        """Base (scalar) type of ``self.ctype``."""
        return cl_basetype(self.ctype)

    @property
    def basetype(self):
        """Alias of :attr:`btype`."""
        return self.btype

    @property
    def components(self):
        """Number of vector components of ``self.ctype``."""
        return cl_components(self.ctype)

    @property
    def dtype(self):
        """dtype corresponding to the base type (via ``ctype_to_dtype``)."""
        return ctype_to_dtype(self.btype)

    @property
    def is_signed(self):
        return is_signed(self.dtype)

    @property
    def is_unsigned(self):
        return is_unsigned(self.dtype)

    @property
    def is_integer(self):
        return is_integer(self.dtype)

    @property
    def is_fp(self):
        return is_fp(self.dtype)

    @property
    def is_complex(self):
        # Not supported for OpenCL typed expressions.
        raise NotImplementedError()

    @property
    def is_positive(self):
        """
        Explicit ``positive`` flag given at construction or, when it was not
        provided (``None``), whether the underlying type is unsigned.
        """
        # __new__ stores the flag under ``positive``; getattr guards against
        # instances that were created without going through TypedI.__new__.
        positive = getattr(self, "positive", None)
        if positive is not None:
            return positive
        # Only query the type when no explicit flag was given.
        return self.is_unsigned
class TypedSymbol(TypedI, Symbol):
    """Symbol carrying an OpenCL C type name in its ``ctype`` attribute."""

    def __new__(cls, ctype, **kwds):
        # All keywords (including TypedI's ``positive``) go up the MRO;
        # ``ctype`` is only recorded as an attribute.
        symbol = super().__new__(cls, **kwds)
        symbol.ctype = ctype
        return symbol
class TypedExpr(TypedI, Expr):
    """
    Symbolic expression tagged with an OpenCL C type name.

    ``ctype`` is passed as the first expression argument and is also stored
    as the ``ctype`` attribute; it must be a ``str``.
    """

    def __new__(cls, ctype, *args):
        base_new = super().__new__
        try:
            # Ask for an unevaluated node when the base constructor allows it.
            node = base_new(cls, ctype, *args, evaluate=False)
        except TypeError:
            # Fall back for base constructors that reject ``evaluate``.
            node = base_new(cls, ctype, *args)
        check_instance(ctype, str)
        node.ctype = ctype
        return node
class TypedExprWrapper(TypedExpr):
    """Attach a ``ctype`` to an expression; it prints exactly as the wrapped expression."""

    def __new__(cls, ctype, expr):
        wrapper = super().__new__(cls, ctype, expr)
        wrapper.expr = expr
        return wrapper

    def _ccode(self, printer):
        # The type tag does not appear in the generated code.
        inner = self.expr
        return printer._print(inner)
class OpenClConvert(TypedExpr):
    """Explicit conversion printed as ``convert_<ctype>(<expr>)``."""

    def __new__(cls, ctype, expr):
        node = super().__new__(cls, ctype, expr)
        node.expr = expr
        return node

    def _ccode(self, printer):
        operand = printer._print(self.expr)
        return "convert_{}({})".format(self.ctype, operand)
class OpenClCast(TypedExpr):
    """C-style cast printed as ``(<ctype>)(<expr>)``."""

    def __new__(cls, ctype, expr):
        node = super().__new__(cls, ctype, expr)
        node.expr = expr
        return node

    def _ccode(self, printer):
        operand = printer._print(self.expr)
        return "({})({})".format(self.ctype, operand)
class OpenClBool(TypedExpr):
    """
    Convert a scalar boolean condition (ie. an int in OpenCL)
    to a compatible vector boolean condition (ie. all bits set)
    prior to vectorization. Also force the minimal integer rank.

    ``expr`` must carry a ``ctype`` of ``short``, ``int`` or ``long``; the
    resulting node is typed ``char``.
    """

    def __new__(cls, expr):
        # Report the offending input type (``ctype`` is only assigned below).
        assert expr.ctype in ("short", "int", "long"), expr.ctype
        ctype = "char"  # force lowest integer rank (to force promotion later)
        obj = super().__new__(cls, ctype, expr)
        obj.expr = expr
        return obj

    def _ccode(self, printer):
        # Negating a scalar boolean maps 0 to 0 (no bit set) and 1 to -1,
        # which has all bits set in two's complement.
        expr = printer._print(self.expr)
        # The negated value is returned without any extra cast: casting it to
        # the unsigned type, e.g. ``(u<ctype>)(...)``, would break conditionals
        # if further promotion is needed afterwards.
        return f"(-({expr}))"
class Return(Expr):
    """Emit a ``return <expr>;`` statement through the printer's code generator."""

    def __new__(cls, expr):
        obj = super().__new__(cls, expr)
        obj.expr = expr
        return obj

    def _ccode(self, printer):
        expr = printer._print(self.expr)
        # The statement is appended directly to the generated code; the
        # printed value itself is only the termination sentinel.
        printer.codegen.append(f"return {expr};")
        return InstructionTermination
class NumericalConstant(TypedExpr):
    """Typed constant whose literal text is produced by ``printer.typegen.dump``."""

    def __new__(cls, ctype, value):
        const = super().__new__(cls, ctype, value)
        const.value = value
        return const

    def _ccode(self, printer):
        dump = printer.typegen.dump
        return dump(self.value)

    @classmethod
    def build(cls, val, typegen):
        """Create a constant for ``val`` typed as ``typegen.dumped_type(val)``."""
        return cls(typegen.dumped_type(val), val)
class IntegerConstant(NumericalConstant):
    """Marker subclass of NumericalConstant (presumably for integer values); adds no behavior."""
class FloatingPointConstant(NumericalConstant):
    """Marker subclass of NumericalConstant (presumably for floating-point values); adds no behavior."""
class ComplexFloatingPointConstant(NumericalConstant):
    """Complex constant printed as ``((<ctype>)(<real>, <imag>))``."""

    def _ccode(self, printer):
        dump = printer.typegen.dump
        real_part = dump(self.value.real)
        imag_part = dump(self.value.imag)
        return f"(({self.ctype})({real_part}, {imag_part}))"
class OpenClVariable(TypedExpr):
    """Typed wrapper around a callable code generator variable ``var``."""

    def __new__(cls, ctype, var, *args):
        wrapped = super().__new__(cls, ctype, var, *args)
        wrapped.var = var
        return wrapped

    def _ccode(self, printer):
        # Whatever calling the variable returns is used as its code.
        render = self.var
        return render()
class OpenClIndexedVariable(OpenClVariable):
    """
    Indexed access ``var[index]``.

    When ``index.var.dim`` exists, the node's ctype is widened to a vector with
    ``components * dim`` components. The access then prints as a vector literal
    gathering ``var[index.var[i]]`` for each ``i``. Otherwise ``dim`` is 1 and
    ``ctype`` is kept.
    """

    def __new__(cls, ctype, var, index):
        try:
            dim = index.var.dim
            components = cl_components(ctype)
            widened = cls.vtype(cl_basetype(ctype), components * dim)
        except AttributeError:
            # Index without a ``dim``: scalar access, ctype unchanged.
            dim = 1
        else:
            ctype = widened
        node = super().__new__(cls, ctype, var, index)
        node.index = index
        node.dim = dim
        return node

    def _ccode(self, printer):
        symbolic_buffers = (OpenClSymbolicBuffer, OpenClSymbolicNdBuffer)
        if not isinstance(self.var, symbolic_buffers):
            # Best effort: let the variable render its own indexed access,
            # falling back to generic printing on any failure.
            try:
                return self.var[self.index]
            except Exception:
                pass
        var = printer._print(self.var)
        if self.dim <= 1:
            return f"{var}[{printer._print(self.index)}]"
        gathered = ", ".join(f"{var}[{self.index.var[i]}]" for i in range(self.dim))
        return f"({self.ctype})({gathered})"
class OpenClAssignment(TypedExpr):
    """Emit ``<var> <op> <rhs>;`` through the printer's code generator (``op`` is inserted verbatim)."""

    def __new__(cls, ctype, var, op, rhs):
        node = super().__new__(cls, ctype, var, op, rhs)
        node.var = var
        node.op = op
        node.rhs = rhs
        return node

    def _ccode(self, printer):
        lhs_code, rhs_code = printer._print(self.var), printer._print(self.rhs)
        printer.codegen.append(f"{lhs_code} {self.op} {rhs_code};")
        return InstructionTermination
class FunctionCall(TypedExpr):
    """Typed call of a code generator function object ``fn`` with keyword arguments ``fn_kwds``."""

    def __new__(cls, ctype, fn, fn_kwds):
        call = super().__new__(cls, ctype, fn, fn_kwds)
        call.fn, call.fn_kwds = fn, fn_kwds
        return call

    def _ccode(self, printer):
        # The function object renders the call itself.
        kwargs = self.fn_kwds
        return self.fn(**kwargs)

    def _sympystr(self, printer):
        return "FunctionCall({})".format(self.fn.name)
class VStore(Expr):
    """Emit the statement built by ``codegen.vstore`` (width ``n``, extra ``opts``) for ``data`` at ``ptr``/``offset``."""

    def __new__(cls, ptr, offset, data, n=1, **opts):
        node = super().__new__(cls, ptr, offset, data, n)
        node.ptr, node.offset, node.data, node.n = ptr, offset, data, n
        node.opts = opts
        return node

    def _ccode(self, printer):
        codegen = printer.codegen
        statement = codegen.vstore(
            n=self.n, ptr=self.ptr, offset=self.offset, data=self.data, **self.opts
        )
        codegen.append(statement)
        return InstructionTermination
class VStoreIf(VStore):
    """Conditional vector store; code emission is delegated to ``codegen.vstore_if``."""

    def __new__(cls, cond, scalar_cond, ptr, offset, data, n, **opts):
        node = super().__new__(cls, ptr, offset, data, n, **opts)
        node.cond = cond
        node.scalar_cond = scalar_cond
        return node

    def _ccode(self, printer):
        call_kwds = dict(
            cond=self.cond,
            scalar_cond=self.scalar_cond,
            n=self.n,
            ptr=self.ptr,
            offset=self.offset,
            data=self.data,
        )
        # vstore_if emits the code itself; nothing is appended here.
        printer.codegen.vstore_if(**call_kwds, **self.opts)
        return InstructionTermination
class VLoad(TypedExpr):
    """
    Vector load built by ``codegen.vload``.

    If ``dst`` is truthy, the loaded value is assigned to it via
    ``dst.affect`` and the termination sentinel is printed. Otherwise the
    load expression itself is printed.
    """

    def __new__(cls, ctype, ptr, offset, dst=None, n=1, **opts):
        node = super().__new__(cls, ctype, ptr, offset, dst, n)
        node.ptr, node.offset, node.dst, node.n = ptr, offset, dst, n
        node.opts = opts
        return node

    def _ccode(self, printer):
        loaded = printer.codegen.vload(
            n=self.n, ptr=self.ptr, offset=self.offset, **self.opts
        )
        if not self.dst:
            return loaded
        self.dst.affect(printer.codegen, loaded)
        return InstructionTermination
class VLoadIf(VLoad):
    """
    Conditional vector load; code emission is delegated to ``codegen.vload_if``.

    Parameters
    ----------
    cond, scalar_cond:
        Conditions forwarded to ``codegen.vload_if``.
    ptr, offset, n:
        Source pointer, offset and vector width forwarded to ``codegen.vload_if``.
    dst:
        Destination forwarded to ``codegen.vload_if``.
    default_value:
        Forwarded to ``codegen.vload_if``.
    ctype: str, optional
        Type tag of the node; defaults to ``dst.ctype``.
    **opts:
        Extra keyword arguments forwarded to ``codegen.vload_if``.
    """

    def __new__(
        cls, cond, scalar_cond, ptr, offset, dst, n, default_value, ctype=None, **opts
    ):
        if ctype is None:
            # NOTE(review): assumes ``dst`` is a typed codegen variable exposing
            # ``ctype`` — confirm against callers.
            ctype = dst.ctype
        # VLoad.__new__ takes ``ctype`` as its first argument; pass the rest by
        # keyword so each value lands on the intended attribute.
        obj = super().__new__(cls, ctype, ptr, offset, dst=dst, n=n, **opts)
        obj.cond = cond
        obj.scalar_cond = scalar_cond
        obj.default_value = default_value
        return obj

    def _ccode(self, printer):
        printer.codegen.vload_if(
            cond=self.cond,
            scalar_cond=self.scalar_cond,
            n=self.n,
            ptr=self.ptr,
            offset=self.offset,
            dst=self.dst,
            default_value=self.default_value,
            **self.opts,
        )
        return InstructionTermination
class IfElse(Expr):
    """
    Conditional code emission.

    ``conditions`` is normalized to a tuple and ``all_exprs`` to a list of
    branches (one list of expressions per condition). A flat list of
    expressions is accepted when there is exactly one condition.
    ``else_exprs`` is optional.
    """

    def __new__(cls, conditions, all_exprs, else_exprs=None):
        conditions = to_tuple(conditions)
        all_exprs = to_list(all_exprs)
        if else_exprs is not None:
            else_exprs = to_list(else_exprs)
        assert len(all_exprs) >= 1
        if not isinstance(all_exprs[0], list):
            # Single branch given as a flat list of expressions.
            assert len(conditions) == 1
            all_exprs = [all_exprs]
        assert len(conditions) == len(all_exprs) >= 1
        node = super().__new__(cls, conditions, all_exprs, else_exprs)
        node.conditions = conditions
        node.all_exprs = all_exprs
        node.else_exprs = else_exprs
        return node

    def _ccode(self, printer):
        codegen = printer.codegen
        emit = printer._print
        # NOTE(review): each (condition, branch) pair is emitted in its own
        # ``codegen._if_`` context (no elif chaining is requested here); the
        # ``_else_`` context follows the last one.
        for condition, branch in zip(self.conditions, self.all_exprs):
            with codegen._if_(condition):
                for statement in branch:
                    emit(statement)
        if self.else_exprs:
            with codegen._else_():
                for statement in self.else_exprs:
                    emit(statement)
        return InstructionTermination
class UpdateVars(Expr):
    """
    Copy each source variable into its destination.

    Non-pointer destinations are assigned directly. Pointer destinations must
    have ``__local`` storage; they are written with
    ``codegen.multi_vstore_if`` and followed by a local barrier.
    """

    def __new__(cls, srcs, dsts, ghosts):
        node = super().__new__(cls, srcs, dsts, ghosts)
        assert srcs and dsts
        node.srcs = srcs
        node.dsts = dsts
        node.init(srcs, dsts, ghosts)
        return node

    def init(self, srcs, dsts, ghosts):
        """Split the (src, dst, ghost) triplets into ``private_stores`` and ``local_stores``."""
        assert len(srcs) == len(dsts)
        private, local = [], []
        for src, dst, ghost in zip(srcs, dsts, ghosts):
            assert not src.is_ptr
            if not dst.is_ptr:
                private.append((src, dst))
                continue
            assert dst.storage == "__local"
            local.append((src, dst, ghost))
        self.private_stores = tuple(private)
        self.local_stores = tuple(local)

    def _ccode(self, printer):
        codegen = printer.codegen
        codegen.jumpline()
        csc = codegen.csc
        dst_names = ", ".join(x() for x in self.dsts)
        src_names = ", ".join(x() for x in self.srcs)
        codegen.comment("Updating {} from {}".format(dst_names, src_names))
        if self.local_stores:
            # Local barrier before any update when local memory is written.
            codegen.barrier(_local=True)
        if self.private_stores:
            with codegen._align_() as al:
                for src, dst in self.private_stores:
                    dst.affect(al, init=src, align=True)
        if self.local_stores:
            srcs, ptrs, offsets = zip(*self.local_stores)
            codegen.multi_vstore_if(
                csc.is_last_active,
                lambda i: f"{csc.full_offset}+{i} < {csc.compute_grid_size[0]}",
                csc.vectorization,
                csc.local_offset,
                srcs,
                ptrs,
                extra_offsets=offsets,
                use_short_circuit=csc.use_short_circuit,
                else_cond=csc.is_active,
            )
            codegen.barrier(_local=True)
        return InstructionTermination
class BuiltinFunctionCall(TypedExpr):
    """Typed call of a builtin function, printed as ``fname(arg0, arg1, ...)``."""

    def __new__(cls, ctype, fname, *fargs):
        call = super().__new__(cls, ctype, fname, fargs)
        call.fname = fname
        call.fargs = fargs
        return call

    def _ccode(self, printer):
        printed_args = [printer._print(arg) for arg in self.fargs]
        return "{}({})".format(self.fname, ", ".join(printed_args))
class BuiltinFunction:
    """Factory producing ``BuiltinFunctionCall`` nodes for a fixed function name."""

    def __new__(cls, fname):
        factory = super().__new__(cls)
        factory.fname = fname
        return factory

    def __call__(self, ctype, *args):
        """Return ``BuiltinFunctionCall(ctype, self.fname, *args)``."""
        name = self.fname
        return BuiltinFunctionCall(ctype, name, *args)
class OpenClPrinter(C99CodePrinter):
    """
    C99 code printer used for OpenCL code generation.

    Expression classes in this module implement ``_ccode(printer)``. Those that
    emit statements write them through ``printer.codegen`` and return the
    ``InstructionTermination`` sentinel; constants use ``printer.typegen``.
    """

    # Replaces (does not extend) the inherited C99CodePrinter defaults.
    _default_settings = {
        "order": None,
        "full_prec": "auto",
        "precision": 15,
        "user_functions": {},
        "human": True,
        "contract": True,
        "dereference": set(),
        "error_on_reserved": False,
        "reserved_word_suffix": "_",
    }

    def __init__(self, typegen, codegen, settings=None, **kwds):
        # A fresh dict per instance instead of a shared mutable default.
        super().__init__(settings={} if settings is None else settings, **kwds)
        self.typegen = typegen
        self.codegen = codegen

    def _handle_UnevaluatedExpr(self, expr):
        return expr

    def doprint(self, expr, terminate=True):
        """
        Print ``expr``.

        With ``terminate=False`` the printed result is returned as is.
        With ``terminate=True`` (default) the expression must have emitted its
        code through ``codegen``: the printed result has to equal
        ``InstructionTermination``. Otherwise RuntimeError is raised. On
        success ``None`` is returned.
        """
        res = super().doprint(expr)
        if not terminate:
            return res
        if res != InstructionTermination:
            msg = (
                "OpenClPrinter failed to generate code for the following expression:\n"
            )
            msg += f" {expr}\n"
            msg += f"Returned value was:\n {res}\n"
            raise RuntimeError(msg)
        return None